{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "524620c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp common._base_auto"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15392f6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "12fa25a4",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization\n",
    "\n",
    "> Machine Learning forecasting methods are defined by many hyperparameters that control their behavior, with effects ranging from their speed and memory requirements to their predictive performance. For a long time, manual hyperparameter tuning prevailed. Because this approach is time-consuming, **automated hyperparameter optimization** methods were introduced, and they have proven more efficient than manual tuning, grid search, and random search.<br><br> The `BaseAuto` class offers a shared API to hyperparameter optimization algorithms like [Optuna](https://docs.ray.io/en/latest/tune/examples/bayesopt_example.html), [HyperOpt](https://docs.ray.io/en/latest/tune/examples/hyperopt_example.html), and [Dragonfly](https://docs.ray.io/en/latest/tune/examples/dragonfly_example.html), among others, through `ray`, which gives you access to grid search, Bayesian optimization, and other state-of-the-art tools like HyperBand. Optuna can also be used directly by setting `backend='optuna'`.<br><br>Understanding the impact of hyperparameters remains a valuable skill, because it helps guide the design of informed search spaces that are faster to explore automatically."
   ]
  },
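  {
   "cell_type": "markdown",
   "id": "7f3a9c21",
   "metadata": {},
   "source": [
    "As a minimal sketch, assuming `ray[tune]` is installed, the cell below defines a search space in the dictionary format expected by `backend='ray'`. Constants are mixed with sampling domains such as `tune.choice`, `tune.randint` and `tune.loguniform`. The keys are illustrative `MLP` hyperparameters. Calling `.sample()` on each domain only approximates what a basic random search does for each trial, because Bayesian searchers such as HyperOpt propose values differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b4d2e63",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray import tune\n",
    "from ray.tune.search.sample import Domain\n",
    "\n",
    "# Illustrative search space: constants pass through unchanged,\n",
    "# while each Domain object is sampled once per trial.\n",
    "search_space = {\n",
    "    'hidden_size': tune.choice([256, 512]),\n",
    "    'num_layers': tune.randint(2, 5),  # upper bound is exclusive: 2, 3 or 4\n",
    "    'learning_rate': tune.loguniform(1e-4, 1e-2),\n",
    "    'input_size': 12,\n",
    "}\n",
    "\n",
    "# Draw one candidate configuration by sampling every domain.\n",
    "candidate = {\n",
    "    k: v.sample() if isinstance(v, Domain) else v\n",
    "    for k, v in search_space.items()\n",
    "}\n",
    "candidate"
   ]
  },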
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e37fd67c",
   "metadata": {},
   "source": [
    "![Figure 1. Example of dataset split (left), validation (yellow) and test (orange). The hyperparameter optimization guiding signal is obtained from the validation set.](imgs_models/data_splits.png)"
   ]
  },
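  {
   "cell_type": "markdown",
   "id": "9c5e3f74",
   "metadata": {},
   "source": [
    "The split in Figure 1 can be expressed as plain index arithmetic. The hypothetical helper `temporal_split` below is not part of NeuralForecast. It assumes a single series of length `n` whose last `test_size` observations form the test set and whose preceding `val_size` observations form the validation set, and it only illustrates that ordering, not how `TimeSeriesDataset` builds its windows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6d7b8c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def temporal_split(n, val_size, test_size):\n",
    "    # Hypothetical helper: (start, end) bounds of the train, validation and test\n",
    "    # segments of one series, with validation immediately preceding test.\n",
    "    if val_size + test_size >= n:\n",
    "        raise ValueError('val_size + test_size must be smaller than the series length')\n",
    "    val_start = n - test_size - val_size\n",
    "    test_start = n - test_size\n",
    "    return {\n",
    "        'train': (0, val_start),\n",
    "        'validation': (val_start, test_start),\n",
    "        'test': (test_start, n),\n",
    "    }\n",
    "\n",
    "# AirPassengers has 144 monthly observations: hold out 12 for validation and 12 for test.\n",
    "temporal_split(144, val_size=12, test_size=12)"
   ]
  },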
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2508f7a9-1433-4ad8-8f2f-0078c6ed6c3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44065066-e72a-431f-938f-1528adef9fe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import warnings\n",
    "from copy import deepcopy\n",
    "from os import cpu_count\n",
    "\n",
    "import torch\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from ray import air, tune\n",
    "from ray.tune.integration.pytorch_lightning import TuneReportCallback\n",
    "from ray.tune.search.basic_variant import BasicVariantGenerator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45cecbda-68c8-4426-a186-9a2a94dcc54e",
   "metadata": {},
   "outputs": [],
   "source": [
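    "#| exporti\n",
    "class MockTrial:\n",
    "    # Stand-in for an optuna trial: every `suggest_*` call returns a placeholder\n",
    "    # string, so calling a config function with it exposes the constant entries.\n",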
    "    def suggest_int(*args, **kwargs):\n",
    "        return 'int'\n",
    "    def suggest_categorical(*args, **kwargs):\n",
    "        return 'categorical'\n",
    "    def suggest_uniform(*args, **kwargs):\n",
    "        return 'uniform'\n",
    "    def suggest_loguniform(*args, **kwargs):\n",
    "        return 'loguniform'\n",
    "    def suggest_float(*args, **kwargs):\n",
    "        if 'log' in kwargs:\n",
    "            return 'quantized_log'\n",
    "        elif 'step' in kwargs:\n",
    "            return 'quantized_loguniform'\n",
    "        return 'float'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c253583-8239-4abe-8a04-0c0ba635d8a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class BaseAuto(pl.LightningModule):\n",
    "    \"\"\"\n",
    "    Class for Automatic Hyperparameter Optimization. It builds on top of `ray`\n",
    "    (or directly on `optuna` with `backend='optuna'`) to give access to a wide variety\n",
    "    of hyperparameter optimization tools, ranging from classic grid search to Bayesian\n",
    "    optimization and the HyperBand algorithm.\n",
    "\n",
    "    The validation loss to be optimized is defined by the `valid_loss` init argument\n",
    "    (which defaults to `loss`); the `config` contains the rest of the hyperparameter search space.\n",
    "\n",
    "    It is important to note that the success of this hyperparameter optimization\n",
    "    heavily relies on a strong correlation between the validation and test periods.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    cls_model : PyTorch/PyTorchLightning model\n",
    "        See `neuralforecast.models` [collection here](https://nixtla.github.io/neuralforecast/models.html).\n",
    "    h : int\n",
    "        Forecast horizon\n",
    "    loss : PyTorch module\n",
    "        Instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).\n",
    "    valid_loss : PyTorch module\n",
    "        Instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html). If None, `loss` is used.\n",
    "    config : dict or callable\n",
    "        Dictionary with ray.tune defined search space or function that takes an optuna trial and returns a configuration dict.\n",
    "    search_alg : ray.tune.search variant or optuna.sampler\n",
    "        For ray see https://docs.ray.io/en/latest/tune/api_docs/suggestion.html\n",
    "        For optuna see https://optuna.readthedocs.io/en/stable/reference/samplers/index.html.\n",
    "    num_samples : int\n",
    "        Number of hyperparameter optimization steps/samples.\n",
    "    cpus : int (default=os.cpu_count())\n",
    "        Number of cpus to use during optimization. Only used with ray tune.\n",
    "    gpus : int (default=torch.cuda.device_count())\n",
    "        Number of gpus to use during optimization, default all available. Only used with ray tune.\n",
    "    refit_with_val : bool\n",
    "        Whether the refit of the best model should preserve `val_size`. Always enabled when `config` contains `early_stop_patience_steps`.\n",
    "    verbose : bool\n",
    "        Whether to track progress.\n",
    "    alias : str, optional (default=None)\n",
    "        Custom name of the model.\n",
    "    backend : str (default='ray')\n",
    "        Backend to use for searching the hyperparameter space, can be either 'ray' or 'optuna'.\n",
    "    callbacks : list of callable, optional (default=None)\n",
    "        List of functions to call during the optimization process.\n",
    "        ray reference: https://docs.ray.io/en/latest/tune/tutorials/tune-metrics.html\n",
    "        optuna reference: https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/007_optuna_callback.html\n",
    "    \"\"\"\n",
    "    def __init__(self, \n",
    "                 cls_model,\n",
    "                 h,\n",
    "                 loss,\n",
    "                 valid_loss,\n",
    "                 config, \n",
    "                 search_alg=BasicVariantGenerator(random_state=1),\n",
    "                 num_samples=10,\n",
    "                 cpus=cpu_count(),\n",
    "                 gpus=torch.cuda.device_count(),\n",
    "                 refit_with_val=False,\n",
    "                 verbose=False,\n",
    "                 alias=None,\n",
    "                 backend='ray',\n",
    "                 callbacks=None,\n",
    "                ):\n",
    "        super(BaseAuto, self).__init__()\n",
    "        with warnings.catch_warnings(record=False):\n",
    "            warnings.filterwarnings('ignore')\n",
    "            # the following line issues a warning about the loss attribute being saved\n",
    "            # but we do want to save it\n",
    "            self.save_hyperparameters() # Allows instantiation from a checkpoint from class\n",
    "\n",
    "        if backend == 'ray':\n",
    "            if not isinstance(config, dict):\n",
    "                raise ValueError(\n",
    "                    \"You have to provide a dict as `config` when using `backend='ray'`\"\n",
    "                )\n",
    "            config_base = deepcopy(config)\n",
    "        elif backend == 'optuna':\n",
    "            if not callable(config):\n",
    "                raise ValueError(\n",
    "                    \"You have to provide a function that takes a trial and returns a dict as `config` when using `backend='optuna'`\"\n",
    "                )\n",
    "            # extract constant values from the config fn for validations\n",
    "            config_base = config(MockTrial())\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown backend {backend}. The supported backends are 'ray' and 'optuna'.\")\n",
    "        if config_base.get('h', None) is not None:\n",
    "            raise Exception(\"Please use `h` init argument instead of `config['h']`.\")\n",
    "        if config_base.get('loss', None) is not None:\n",
    "            raise Exception(\"Please use `loss` init argument instead of `config['loss']`.\")\n",
    "        if config_base.get('valid_loss', None) is not None:\n",
    "            raise Exception(\"Please use `valid_loss` init argument instead of `config['valid_loss']`.\")\n",
    "        # Early stopping is configured if the key is present; in that case\n",
    "        # the best model is refit keeping the validation set (see refit_with_val).\n",
    "        if 'early_stop_patience_steps' in config_base.keys():\n",
    "            self.early_stop_patience_steps = 1\n",
    "        else:\n",
    "            self.early_stop_patience_steps = -1\n",
    "\n",
    "        if callable(config):\n",
    "            # reset config_base here to save params to override in the config fn\n",
    "            config_base = {}\n",
    "\n",
    "        # Add losses to config and protect valid_loss default\n",
    "        config_base['h'] = h\n",
    "        config_base['loss'] = loss\n",
    "        if valid_loss is None:\n",
    "            valid_loss = loss\n",
    "        config_base['valid_loss'] = valid_loss\n",
    "\n",
    "        if isinstance(config, dict):\n",
    "            self.config = config_base            \n",
    "        else:\n",
    "            def config_f(trial):\n",
    "                return {**config(trial), **config_base}\n",
    "            self.config = config_f            \n",
    "        \n",
    "        self.h = h\n",
    "        self.cls_model = cls_model\n",
    "        self.loss = loss\n",
    "        self.valid_loss = valid_loss\n",
    "\n",
    "        self.num_samples = num_samples\n",
    "        self.search_alg = search_alg\n",
    "        self.cpus = cpus\n",
    "        self.gpus = gpus\n",
    "        self.refit_with_val = refit_with_val or self.early_stop_patience_steps > 0\n",
    "        self.verbose = verbose\n",
    "        self.alias = alias\n",
    "        self.backend = backend\n",
    "        self.callbacks = callbacks\n",
    "\n",
    "        # Base Class attributes\n",
    "        self.SAMPLING_TYPE = cls_model.SAMPLING_TYPE\n",
    "\n",
    "    def __repr__(self):\n",
    "        return type(self).__name__ if self.alias is None else self.alias\n",
    "    \n",
    "    def _train_tune(self, config_step, cls_model, dataset, val_size, test_size):\n",
    "        \"\"\" BaseAuto._train_tune\n",
    "\n",
    "        Internal function that instantiates a NF class model, trains it, and reports\n",
    "        the validation loss (ptl/val_loss) that guides the hyperparameter\n",
    "        exploration.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `config_step`: Dict, initialization parameters of a NF model.<br>\n",
    "        `cls_model`: NeuralForecast model class, yet to be instantiated.<br>\n",
    "        `dataset`: NeuralForecast dataset, to fit the model.<br>\n",
    "        `val_size`: int, validation size for temporal cross-validation.<br>\n",
    "        `test_size`: int, test size for temporal cross-validation.<br>\n",
    "        \"\"\"\n",
    "        metrics = {\"loss\": \"ptl/val_loss\", \"train_loss\": \"train_loss\"}\n",
    "        callbacks = [TuneReportCallback(metrics, on=\"validation_end\")]\n",
    "        if 'callbacks' in config_step.keys():\n",
    "            callbacks.extend(config_step['callbacks'])\n",
    "        config_step = {**config_step, **{'callbacks': callbacks}}\n",
    "\n",
    "        # Protect dtypes from tune samplers\n",
    "        if 'batch_size' in config_step.keys():\n",
    "            config_step['batch_size'] = int(config_step['batch_size'])\n",
    "        if 'windows_batch_size' in config_step.keys():\n",
    "            config_step['windows_batch_size'] = int(config_step['windows_batch_size'])\n",
    "\n",
    "        # Tune session receives validation signal\n",
    "        # from the specialized PL TuneReportCallback\n",
    "        _ = self._fit_model(cls_model=cls_model,\n",
    "                            config=config_step,\n",
    "                            dataset=dataset,\n",
    "                            val_size=val_size,\n",
    "                            test_size=test_size)\n",
    "\n",
    "    def _tune_model(self, cls_model, dataset, val_size, test_size,\n",
    "                cpus, gpus, verbose, num_samples, search_alg, config):\n",
    "        train_fn_with_parameters = tune.with_parameters(\n",
    "            self._train_tune,\n",
    "            cls_model=cls_model,\n",
    "            dataset=dataset,\n",
    "            val_size=val_size,\n",
    "            test_size=test_size,\n",
    "        )\n",
    "\n",
    "        # Device\n",
    "        if gpus > 0:\n",
    "            device_dict = {'gpu':gpus}\n",
    "        else:\n",
    "            device_dict = {'cpu':cpus}\n",
    "\n",
    "        # on Windows, prevent long trial directory names\n",
    "        import platform\n",
    "        trial_dirname_creator=(lambda trial: f\"{trial.trainable_name}_{trial.trial_id}\") if platform.system() == 'Windows' else None\n",
    "\n",
    "        tuner = tune.Tuner(\n",
    "            tune.with_resources(train_fn_with_parameters, device_dict),\n",
    "            run_config=air.RunConfig(callbacks=self.callbacks, verbose=verbose),\n",
    "            tune_config=tune.TuneConfig(\n",
    "                metric=\"loss\",\n",
    "                mode=\"min\",\n",
    "                num_samples=num_samples, \n",
    "                search_alg=search_alg,\n",
    "                trial_dirname_creator=trial_dirname_creator,\n",
    "            ),\n",
    "            param_space=config,\n",
    "        )\n",
    "        results = tuner.fit()\n",
    "        return results\n",
    "\n",
    "    @staticmethod\n",
    "    def _ray_config_to_optuna(ray_config):\n",
    "        def optuna_config(trial):\n",
    "            out = {}\n",
    "            for k, v in ray_config.items():\n",
    "                if hasattr(v, 'sampler'):\n",
    "                    sampler = v.sampler\n",
    "                    if isinstance(sampler, tune.search.sample.Integer.default_sampler_cls):\n",
    "                        v = trial.suggest_int(k, v.lower, v.upper)\n",
    "                    elif isinstance(sampler, tune.search.sample.Categorical.default_sampler_cls):\n",
    "                        v = trial.suggest_categorical(k, v.categories)                    \n",
    "                    elif isinstance(sampler, tune.search.sample.Uniform):\n",
    "                        v = trial.suggest_uniform(k, v.lower, v.upper)\n",
    "                    elif isinstance(sampler, tune.search.sample.LogUniform):\n",
    "                        v = trial.suggest_loguniform(k, v.lower, v.upper)\n",
    "                    elif isinstance(sampler, tune.search.sample.Quantized):\n",
    "                        if isinstance(sampler.get_sampler(), tune.search.sample.Float._LogUniform):\n",
    "                            v = trial.suggest_float(k, v.lower, v.upper, log=True)\n",
    "                        elif isinstance(sampler.get_sampler(), tune.search.sample.Float._Uniform):\n",
    "                            v = trial.suggest_float(k, v.lower, v.upper, step=sampler.q)\n",
    "                    else:\n",
    "                        raise ValueError(f\"Couldn't translate {type(v)} to optuna.\")\n",
    "                out[k] = v\n",
    "            return out\n",
    "        return optuna_config\n",
    "\n",
    "    def _optuna_tune_model(\n",
    "        self,\n",
    "        cls_model,\n",
    "        dataset,\n",
    "        val_size,\n",
    "        test_size,\n",
    "        verbose,\n",
    "        num_samples,\n",
    "        search_alg,\n",
    "        config,\n",
    "        distributed_config,\n",
    "    ):\n",
    "        import optuna\n",
    "\n",
    "        def objective(trial):\n",
    "            user_cfg = config(trial)\n",
    "            cfg = deepcopy(user_cfg)\n",
    "            model = self._fit_model(\n",
    "                cls_model=cls_model,\n",
    "                config=cfg,\n",
    "                dataset=dataset,\n",
    "                val_size=val_size,\n",
    "                test_size=test_size,\n",
    "                distributed_config=distributed_config,\n",
    "            )\n",
    "            trial.set_user_attr('ALL_PARAMS', user_cfg)\n",
    "            metrics = model.metrics\n",
    "            trial.set_user_attr('METRICS', {\n",
    "                \"loss\": metrics[\"ptl/val_loss\"],\n",
    "                \"train_loss\": metrics[\"train_loss\"],\n",
    "            })\n",
    "            return trial.user_attrs['METRICS']['loss']\n",
    "\n",
    "        if isinstance(search_alg, optuna.samplers.BaseSampler):\n",
    "            sampler = search_alg\n",
    "        else:\n",
    "            sampler = None\n",
    "\n",
    "        study = optuna.create_study(sampler=sampler, direction='minimize')\n",
    "        study.optimize(\n",
    "            objective,\n",
    "            n_trials=num_samples,\n",
    "            show_progress_bar=verbose,\n",
    "            callbacks=self.callbacks,\n",
    "        )\n",
    "        return study\n",
    "\n",
    "    def _fit_model(self, cls_model, config,\n",
    "                   dataset, val_size, test_size, distributed_config=None):\n",
    "        model = cls_model(**config)\n",
    "        model = model.fit(\n",
    "            dataset,\n",
    "            val_size=val_size, \n",
    "            test_size=test_size,\n",
    "            distributed_config=distributed_config,\n",
    "        )\n",
    "        return model\n",
    "\n",
    "    def fit(self, dataset, val_size=0, test_size=0, random_seed=None, distributed_config=None):\n",
    "        \"\"\" BaseAuto.fit\n",
    "\n",
    "        Perform the hyperparameter optimization over the search space `config`\n",
    "        (a dict for `backend='ray'`, a function of a trial for `backend='optuna'`).\n",
    "\n",
    "        The optimization is performed on the `TimeSeriesDataset` using temporal cross-validation,\n",
    "        with the validation set sequentially preceding the test set. The best configuration is\n",
    "        then refit on `dataset`.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `dataset`: NeuralForecast's `TimeSeriesDataset` see details [here](https://nixtla.github.io/neuralforecast/tsdataset.html)<br>\n",
    "        `val_size`: int, size of the temporal validation set; if not positive, the forecast horizon `h` is used (default 0).<br>\n",
    "        `test_size`: int, size of temporal test set (default 0).<br>\n",
    "        `random_seed`: int=None, random_seed for hyperparameter exploration algorithms, not yet implemented.<br>\n",
    "        `distributed_config`: optional, distributed training configuration passed to the model's `fit`; not supported with `backend='ray'`.<br>\n",
    "        **Returns:**<br>\n",
    "        `self`: fitted instance of `BaseAuto` with best hyperparameters and results.<br>\n",
    "        \"\"\"\n",
    "        # Hyperparameter selection needs val_size > 0;\n",
    "        # fall back to the forecast horizon h otherwise.\n",
    "        search_alg = deepcopy(self.search_alg)\n",
    "        val_size = val_size if val_size > 0 else self.h\n",
    "        if self.backend == 'ray':\n",
    "            if distributed_config is not None:\n",
    "                raise ValueError('distributed training is not supported for the ray backend.')\n",
    "            results = self._tune_model(\n",
    "                cls_model=self.cls_model,\n",
    "                dataset=dataset,\n",
    "                val_size=val_size,\n",
    "                test_size=test_size, \n",
    "                cpus=self.cpus,\n",
    "                gpus=self.gpus,\n",
    "                verbose=self.verbose,\n",
    "                num_samples=self.num_samples, \n",
    "                search_alg=search_alg, \n",
    "                config=self.config,\n",
    "            )\n",
    "            best_config = results.get_best_result().config\n",
    "        else:\n",
    "            results = self._optuna_tune_model(\n",
    "                cls_model=self.cls_model,\n",
    "                dataset=dataset,\n",
    "                val_size=val_size, \n",
    "                test_size=test_size, \n",
    "                verbose=self.verbose,\n",
    "                num_samples=self.num_samples, \n",
    "                search_alg=search_alg, \n",
    "                config=self.config,\n",
    "                distributed_config=distributed_config,\n",
    "            )\n",
    "            best_config = results.best_trial.user_attrs['ALL_PARAMS']\n",
    "        self.model = self._fit_model(\n",
    "            cls_model=self.cls_model,\n",
    "            config=best_config,\n",
    "            dataset=dataset,\n",
    "            val_size=val_size * self.refit_with_val,\n",
    "            test_size=test_size,\n",
    "            distributed_config=distributed_config,\n",
    "        )\n",
    "        self.results = results\n",
    "\n",
    "        # Added attributes for compatibility with NeuralForecast core\n",
    "        self.futr_exog_list = self.model.futr_exog_list\n",
    "        self.hist_exog_list = self.model.hist_exog_list\n",
    "        self.stat_exog_list = self.model.stat_exog_list\n",
    "        return self\n",
    "\n",
    "    def predict(self, dataset, step_size=1, **data_kwargs):\n",
    "        \"\"\" BaseAuto.predict\n",
    "\n",
    "        Predict with the model that was refit using the best hyperparameters found on validation.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `dataset`: NeuralForecast's `TimeSeriesDataset` see details [here](https://nixtla.github.io/neuralforecast/tsdataset.html)<br>\n",
    "        `step_size`: int, steps between sequential predictions (default 1).<br>\n",
    "        `**data_kwargs`: additional parameters for the dataset module.<br>\n",
    "        **Returns:**<br>\n",
    "        `y_hat`: numpy predictions of the `NeuralForecast` model.<br>\n",
    "        \"\"\"\n",
    "        return self.model.predict(dataset=dataset, \n",
    "                                  step_size=step_size, **data_kwargs)\n",
    "\n",
    "    def set_test_size(self, test_size):\n",
    "        self.model.set_test_size(test_size)\n",
    "\n",
    "    def get_test_size(self):\n",
    "        return self.model.test_size\n",
    "    \n",
    "    def save(self, path):\n",
    "        \"\"\" BaseAuto.save\n",
    "\n",
    "        Save the fitted model to disk.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `path`: str, path to save the model.<br>\n",
    "        \"\"\"\n",
    "        self.model.save(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2376ed06",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(BaseAuto, title_level=3)"
   ]
  },
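  {
   "cell_type": "markdown",
   "id": "b2e4f6a8",
   "metadata": {},
   "source": [
    "With `backend='optuna'`, `config` is a function of an Optuna trial. `BaseAuto` wraps this function so that the `h`, `loss` and `valid_loss` init arguments are merged on top of whatever it returns (`{**config(trial), **config_base}`). The sketch below reproduces that merge outside the class and assumes `optuna` is installed. It uses `optuna.trial.FixedTrial` to evaluate the function with fixed parameter values instead of running a study, and for brevity it replaces the loss objects with placeholder strings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3f5a7b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import optuna\n",
    "\n",
    "def config_fn(trial):\n",
    "    # Search space written with the Optuna trial API.\n",
    "    return {\n",
    "        'hidden_size': trial.suggest_categorical('hidden_size', [256, 512]),\n",
    "        'num_layers': trial.suggest_int('num_layers', 2, 4),  # bounds are inclusive\n",
    "        'input_size': 12,\n",
    "    }\n",
    "\n",
    "# Entries that BaseAuto takes from its init arguments (placeholders instead of loss objects).\n",
    "config_base = {'h': 12, 'loss': 'MAE()', 'valid_loss': 'MSE()'}\n",
    "\n",
    "def merged_config(trial):\n",
    "    # Same merge order as BaseAuto: init arguments take precedence.\n",
    "    return {**config_fn(trial), **config_base}\n",
    "\n",
    "merged_config(optuna.trial.FixedTrial({'hidden_size': 512, 'num_layers': 3}))"
   ]
  },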
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "623ebb06",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(BaseAuto.fit, title_level=3)"
   ]
  },
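  {
   "cell_type": "markdown",
   "id": "d4a6b8c0",
   "metadata": {},
   "source": [
    "`fit` replaces a non-positive `val_size` with the horizon `h`. It then refits the best configuration with `val_size * self.refit_with_val`, which evaluates to `val_size` when `refit_with_val` is `True` and to `0` otherwise. `refit_with_val` is also forced on whenever `config` contains `early_stop_patience_steps`. The hypothetical helper `resolve_val_sizes` below is not part of the library. It only reproduces that arithmetic outside the class to make the resulting sizes explicit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5b7c9d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def resolve_val_sizes(val_size, h, refit_with_val, config):\n",
    "    # Hypothetical helper mirroring BaseAuto's handling of val_size.\n",
    "    # Returns (val_size used during the search, val_size used for the final refit).\n",
    "    # The presence of early stopping in the config forces refit_with_val,\n",
    "    # presumably because early stopping monitors the validation loss.\n",
    "    refit_with_val = refit_with_val or 'early_stop_patience_steps' in config\n",
    "    # A positive validation size is required to rank configurations.\n",
    "    search_val_size = val_size if val_size > 0 else h\n",
    "    # bool * int: True -> search_val_size, False -> 0.\n",
    "    refit_val_size = search_val_size * refit_with_val\n",
    "    return search_val_size, refit_val_size\n",
    "\n",
    "resolve_val_sizes(0, h=12, refit_with_val=False, config={'early_stop_patience_steps': 3})"
   ]
  },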
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69d3c1ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(BaseAuto.predict, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbfd4e8f-2565-4f85-b615-7329a1ae3f43",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import logging\n",
    "import warnings\n",
    "\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "421db156-4ee6-420f-ac9e-f0ddc9781841",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1e776fb-fa7e-49c6-afd2-b30891c83a73",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import optuna\n",
    "import pandas as pd\n",
    "from neuralforecast.models.mlp import MLP\n",
    "from neuralforecast.utils import AirPassengersDF as Y_df\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset\n",
    "from neuralforecast.losses.numpy import mae\n",
    "from neuralforecast.losses.pytorch import MAE, MSE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c26739d-c405-4700-a833-79c3a0fec497",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "Y_train_df = Y_df[Y_df.ds<='1959-12-31'] # 132 train\n",
    "Y_test_df = Y_df[Y_df.ds>'1959-12-31']   # 12 test\n",
    "\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88148bbe-b4c1-41c3-8ce1-4f7695161d99",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RayLogLossesCallback(tune.Callback):\n",
    "    def on_trial_complete(self, iteration, trials, trial, **info):\n",
    "        result = trial.last_result\n",
    "        print(40 * '-' + 'Trial finished' + 40 * '-')\n",
    "        print(f'Train loss: {result[\"train_loss\"]:.2f}. Valid loss: {result[\"loss\"]:.2f}')\n",
    "        print(80 * '-')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae8912d7-9128-42ab-a581-5f63b6ea34eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"hidden_size\": tune.choice([512]),\n",
    "    \"num_layers\": tune.choice([3, 4]),\n",
    "    \"input_size\": 12,\n",
    "    \"max_steps\": 10,\n",
    "    \"val_check_steps\": 5\n",
    "}\n",
    "auto = BaseAuto(h=12, loss=MAE(), valid_loss=MSE(), cls_model=MLP, config=config, num_samples=2, cpus=1, gpus=0, callbacks=[RayLogLossesCallback()])\n",
    "auto.fit(dataset=dataset)\n",
    "y_hat = auto.predict(dataset=dataset)\n",
    "assert mae(Y_test_df['y'].values, y_hat[:, 0]) < 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d46d13-f0d0-4bc0-aba2-bd094a9a78c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def config_f(trial):\n",
    "    return {\n",
    "        \"hidden_size\": trial.suggest_categorical('hidden_size', [512]),\n",
    "        \"num_layers\": trial.suggest_categorical('num_layers', [3, 4]),\n",
    "        \"input_size\": 12,\n",
    "        \"max_steps\": 10,\n",
    "        \"val_check_steps\": 5\n",
    "    }\n",
    "\n",
    "class OptunaLogLossesCallback:\n",
    "    def __call__(self, study, trial):\n",
    "        metrics = trial.user_attrs['METRICS']\n",
    "        print(40 * '-' + 'Trial finished' + 40 * '-')\n",
    "        print(f'Train loss: {metrics[\"train_loss\"]:.2f}. Valid loss: {metrics[\"loss\"]:.2f}')\n",
    "        print(80 * '-')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d979d9df-3a8d-4aab-aaa9-5b66067aef26",
   "metadata": {},
   "outputs": [],
   "source": [
    "auto2 = BaseAuto(h=12, loss=MAE(), valid_loss=MSE(), cls_model=MLP, config=config_f, search_alg=optuna.samplers.RandomSampler(), num_samples=2, backend='optuna', callbacks=[OptunaLogLossesCallback()])\n",
    "auto2.fit(dataset=dataset)\n",
    "assert isinstance(auto2.results, optuna.Study)\n",
    "y_hat2 = auto2.predict(dataset=dataset)\n",
    "assert mae(Y_test_df['y'].values, y_hat2[:, 0]) < 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66ad2eec-dd93-4bc4-ae19-5df4199577be",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "Y_test_df['AutoMLP'] = y_hat[:, 0]\n",
    "\n",
    "pd.concat([Y_train_df, Y_test_df]).drop('unique_id', axis=1).set_index('ds').plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "463d4dc0-b25a-4ce6-9172-5690dc979f0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Unit tests to guarantee that losses are correctly instantiated\n",
    "import pandas as pd\n",
    "from neuralforecast.models.mlp import MLP\n",
    "from neuralforecast.utils import AirPassengersDF as Y_df\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset\n",
    "from neuralforecast.losses.pytorch import MAE, MSE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "882c8331-440a-4758-a56c-07a78c0b1603",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Unit tests to guarantee that losses are correctly instantiated\n",
    "Y_train_df = Y_df[Y_df.ds<='1959-12-31'] # 132 train\n",
    "Y_test_df = Y_df[Y_df.ds>'1959-12-31']   # 12 test\n",
    "\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)\n",
    "config = {\n",
    "    \"hidden_size\": tune.choice([512]),\n",
    "    \"num_layers\": tune.choice([3, 4]),\n",
    "    \"input_size\": 12,\n",
    "    \"max_steps\": 1,\n",
    "    \"val_check_steps\": 1\n",
    "}\n",
    "\n",
    "# Test instantiation\n",
    "auto = BaseAuto(h=12, loss=MAE(), valid_loss=MSE(), \n",
    "                cls_model=MLP, config=config, num_samples=2, cpus=1, gpus=0)\n",
    "test_eq(str(type(auto.loss)), \"<class 'neuralforecast.losses.pytorch.MAE'>\")\n",
    "test_eq(str(type(auto.valid_loss)), \"<class 'neuralforecast.losses.pytorch.MSE'>\")\n",
    "\n",
    "# Test validation default\n",
    "auto = BaseAuto(h=12, loss=MSE(), valid_loss=None,\n",
    "                cls_model=MLP, config=config, num_samples=2, cpus=1, gpus=0)\n",
    "test_eq(str(type(auto.loss)), \"<class 'neuralforecast.losses.pytorch.MSE'>\")\n",
    "test_eq(str(type(auto.valid_loss)), \"<class 'neuralforecast.losses.pytorch.MSE'>\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3c8e2d46",
   "metadata": {},
   "source": [
    "### References\n",
    "- [James Bergstra, Remi Bardenet, Yoshua Bengio, and Balazs Kegl (2011). \"Algorithms for Hyper-Parameter Optimization\". In: Advances in Neural Information Processing Systems. url: https://proceedings.neurips.cc/paper/2011/file/86e8f7ab32cfd12577bc2619bc635690-Paper.pdf](https://proceedings.neurips.cc/paper/2011/file/86e8f7ab32cfd12577bc2619bc635690-Paper.pdf)\n",
    "- [Kirthevasan Kandasamy, Karun Raju Vysyaraju, Willie Neiswanger, Biswajit Paria, Christopher R. Collins, Jeff Schneider, Barnabas Poczos, Eric P. Xing (2019). \"Tuning Hyperparameters without Grad Students: Scalable and Robust Bayesian Optimisation with Dragonfly\". Journal of Machine Learning Research. url: https://arxiv.org/abs/1903.06694](https://arxiv.org/abs/1903.06694)\n",
    "- [Lisha Li, Kevin Jamieson, Giulia DeSalvo, Afshin Rostamizadeh, Ameet Talwalkar (2016). \"Hyperband: A Novel Bandit-Based Approach to Hyperparameter Optimization\". Journal of Machine Learning Research. url: https://arxiv.org/abs/1603.06560](https://arxiv.org/abs/1603.06560)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "267cbf1e",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
